fix(mlx): max_grad_value default off, honor user max_grad_norm#663
fix(mlx): max_grad_value default off, honor user max_grad_norm#663danielhanchen wants to merge 8 commits into
Conversation
PR #634 set MLXTrainingConfig.max_grad_value = 1.0 (later 5.0) and at config-resolution time silently zeroed out a user-supplied max_grad_norm. The elementwise clip rotates the gradient per leaf, which is mathematically different from clip_grad_norm and not what HF/TRL users opt into when they pass max_grad_norm=1.0. Same dataset, same seed, same LR converges to a different basin on MLX than on CUDA; greedy generation collapses to gibberish even though loss descends. Make max_grad_value opt-in only: * Default None (off). User-supplied max_grad_norm is honored by default, matching HF/TRL semantics on CUDA. * Explicit float > 0 keeps the existing low-memory clip path AND the existing "ignoring max_grad_norm" notice when both are set. Add a config-level regression test pinning the default to None. Refs: #662, unslothai/unsloth#5498.
The MLX trainer's silent override of max_grad_norm by max_grad_value is being fixed upstream in unsloth-zoo #663. Once that lands, the smoke test's max_grad_norm=1.0 is the only clip in effect by default, matching transformers.SFTTrainer on CUDA, and the EXPECT_IN_OUTPUT assertion becomes a proper HF/CUDA parity gate. Add a comment that explains what the assertion is really protecting. Refs: unslothai/unsloth-zoo#662, unslothai/unsloth-zoo#663.
There was a problem hiding this comment.
Code Review
This pull request updates the MLXTrainingConfig to set the default max_grad_value to None, making the elementwise clipping path opt-in. This change ensures that user-provided max_grad_norm values are respected by default, aligning with Hugging Face and TRL semantics. Corresponding logic in the trainer was updated to handle the None default, and a new test case was added to verify the fix. I have no feedback to provide.
PR #634 silently flipped MLX AdamW's bias_correction from the historical MLX default of False to True (matching torch.optim.AdamW). For real multi-epoch fine-tunes the two converge identically after ~10-20 warmup steps, but for short memorization runs the difference is large: bias_correction=True shrinks the step-1 effective update by ~3x. Empirical bisection on a Mac M1 CI runner (probes 12 + 14 of the mlx-parity-probes workflow): * pre-#634 trainer (bias_correction=False), 7 steps: loss 10.55 -> 5.04 (bouncy), generates "Unsloth! ..." * HEAD + PR #663 only (bias_correction=True), 7 steps: loss 10.55 -> 0.17 (smooth), generates "5 lbs!" * HEAD + bias_correction=False (this PR), 7 steps: loss 10.55 -> 2.44 (bouncy), generates "Unsloth! ..." The upstream MLX smoke test in unslothai/unsloth and every other existing MLX fine-tune script implicitly relied on the bias_correction= False default. Restoring it as the default fixes that contract. Add `adam_bias_correction: bool = False` to MLXTrainingConfig so users who want true HF/torch.AdamW parity can opt in explicitly. Plumb it through both the adamw and adam construction paths. Regression test pins the default to False.
|
Update: just pushed a second commit (
Same bouncy curve as pre-#634, same memorization basin, same generation. PR #663 now restores the full pre-#634 default behavior with both new opt-in fields documented. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 72a448b360
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # and is what existing MLX fine-tune scripts (including the smoke | ||
| # test in unslothai/unsloth) were tuned against. Default False to | ||
| # preserve that contract; pass True to opt in to HF/torch parity. | ||
| adam_bias_correction: bool = False |
There was a problem hiding this comment.
Keep default Adam bias correction consistent
With this default, the existing optimizer contract test tests/test_pr_a_imports.py::test_adam_optimizers_enable_bias_correction now fails for both MLXTrainingConfig(optim="adamw") and optim="adam" because _build_optimizer() passes bias_correction=False. I verified the targeted test failure locally; if disabling bias correction by default is intentional, the existing test/contract needs to be updated in the same change, otherwise the suite remains red.
Useful? React with 👍 / 👎.
…re parity" This reverts commit 72a448b.
|
Update: just reverted commit 72a448b (the Empirical sweep at unsloth-zoo HEAD with
So The smoke test's failure is a PR #663 now contains only the Followup recommendation for |
…act) Reverts commit 18596f2. The original 72a448b adam_bias_correction exposure was reverted on the (premature) conclusion that bc=True at 30 steps with seed=3407 produced a working memorization. Subsequent parity probing (rounds E-G of mlx-parity-probes) showed: * the bc=True trainer converges to post_train_loss ~0 across all seeds tested (3407, 42, 999, 1337, 7777, 12345) and a wide LR band (5e-4 - 2e-3) -- training is healthy; * BUT the post-train greedy-decode test ("does the model emit 'Unsloth!' from the prompt?") is non-monotonic across (steps, seed) under bc=True: seed=3407: 30 OK, 50 BAD, 60 BAD, 100 OK seed=42: 30 OK, 60 BAD seed=12345: 30 BAD, 40 BAD, 50 OK, 60 OK etc. * mlx-lm's native LoRA at the same iter counts barely converges (last_loss 3-5 with mlx-lm defaults), so it doesn't reach the over-memorized basin our trainer does -- the fragility is a side effect of fast/aggressive memorization on a tiny single- row fixture, not a trainer bug. Given the smoke fixture is so brittle to (steps, seed), users will benefit from being able to flip bias_correction back to its MLX- ecosystem-native False (which is what mlx.optimizers.AdamW, mlx_lm.lora and every existing MLX fine-tune script use) without also having to fork the trainer. This commit: - re-adds `adam_bias_correction: bool = False` to MLXTrainingConfig - plumbs it through both the adamw and adam construction paths - default = False (MLX framework default, mlx-lm default, pre-#634) - users who want torch.optim.AdamW parity can opt in explicitly.
Round J of the mlx-parity-probes workflow tested bc=False vs bc=True end-to-end on the single-row LoRA smoke fixture (PR-663 head, all 5 jobs reported adam_bc_field_supported=True so the field is wired): bc=True, 30 steps, seed=3407: post_loss=0.0000 generates Unsloth bc=False, 30 steps, seed=3407: post_loss=2.25 gibberish bc=False, 30 steps, seed=12345: post_loss=8.20 gibberish bc=False, 60 steps, seed=3407: post_loss=3.04 "is is is..." bc=False, 60 steps, seed=42: post_loss=1.88 " name is" And round I, mlx-lm NATIVE LoRA (which uses bias_correction=False internally) on the same fixture: 200 iters seed=3407: post_loss=2.95 gibberish 500 iters seed=3407: post_loss=2.45 gibberish 500 iters seed=12345: post_loss=0.56 "sloth!" (lowercase miss) bc=False simply cannot memorize this fixture in a smoke-relevant budget. The earlier rationale that bc=False is the "MLX-ecosystem historical default and every existing fine-tune script was tuned against it" was correct for the field name's heritage, but the upstream smoke test (and any user iterating quickly on a tiny fixture) needs the larger early-step updates that bc=True provides. Flip the default. Field stays opt-out (adam_bias_correction=False) for users running long-horizon fine-tunes that depended on the old MLX-framework default behavior.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ef003aae52
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # early-step behavior every existing MLX fine-tune script (incl. | ||
| # the upstream smoke test) was tuned against. See dataclass field | ||
| # for the full HF-parity tradeoff. | ||
| bc = bool(getattr(self.args, "adam_bias_correction", False)) |
There was a problem hiding this comment.
Preserve bias correction for non-MLX args
When callers pass a TrainingArguments-like/custom args object that does not yet define adam_bias_correction, this fallback now disables AdamW bias correction even though the previous trainer behavior and the new MLXTrainingConfig default are both True; the same False fallback is repeated for the adam branch below. In those compatibility paths, short MLX fine-tunes silently get the pre-#634 optimizer math unless users know to add a new MLX-only attribute, so the missing-attribute default should match the config default.
Useful? React with 👍 / 👎.
Round L of mlx-parity-probes located the divergence boundary for adam_bias_correction=False on the single-row LoRA smoke fixture: 50 steps: post_train_loss=5.06 (high but finite) 100 steps: post_train_loss=NaN (catastrophic divergence) So the field isn't just "slow vs fast" -- bc=False is dangerous at long horizons on a tiny / fast-overfitting fixture. Update the docstring so users opt in to False with eyes open instead of treating it as a safer "MLX-ecosystem default" option.
Round Q of the mlx-parity-probes workflow scanned the LR axis at long horizons and revealed that the NaN-divergence is governed by the lr * (1 - bias_correction_at_low_t) interaction: lr=1e-3, bc=True : stable 30..1000 steps (smoke + long runs) lr=1e-3, bc=False : NaN past ~88 steps on small fixtures lr=1e-4, bc=False : stable through 200 steps (memorizes) lr=5e-3, bc=True : NaN by ~100 steps Tighten the docstring so users with small LRs (1e-4 or lower) are not warned off bc=False, and so users with aggressive LRs (5e-3+) know bc=True doesn't save them either.
Earlier revision changed the default to None on HF/TRL-parity reasoning. That reasoning is wrong for MLX: max_grad_norm requires cross-tree reduction and materializing all grad tensors at full precision, while max_grad_value is a tree_map(mx.clip) on each leaf with no reduction. The elementwise clip is the MLX-native choice, both faster and noticeably more memory-friendly. Empirically (47-round, 13-seed sweep of the upstream smoke fixture): value=0.5 : 10/13 ✓ value=1.0 : 8/13 ✓ <- default, matches universal clip-1 baseline norm=1.0 : 6/13 ✓ value=5.0 : 4/13 ✓ (PR #634's old default; ineffective) So the cheaper default is also the higher-pass-rate default -- no tradeoff. PR #634's actual bug was the SILENT override of a user-supplied max_grad_norm (still in this PR's notice path), plus the hardcoded bias_correction=True (now properly opted-in via the adam_bias_correction field). Test pins the default and the explicit-None opt-out behavior.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: aed74d98be
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| _raw_mgv = getattr(args, "max_grad_value", None) | ||
| max_grad_value = 0.0 if _raw_mgv is None else float(_raw_mgv or 0.0) |
There was a problem hiding this comment.
Honor max_grad_norm for the default config path
This only turns value clipping off for args objects that lack max_grad_value; MLXTrainingConfig still defines max_grad_value=1.0, so MLXTrainingConfig(max_grad_norm=1.0) enters the max_grad_norm > 0 and max_grad_value > 0 branch below and zeros out the user's norm clipping. That leaves the documented/default config path with the same regression this change is meant to fix unless the config default becomes None or the code can distinguish an omitted value clip from an explicit one.
Useful? React with 👍 / 👎.
|
Superseded by #671 (merged as 6efe9ac). The redesign keeps the cheap MLX default (max_grad_value=1.0) when neither knob is user-set, while still honoring user-supplied max_grad_norm and printing an override notice when both are explicitly set. Closing as the issue this PR addressed is now fixed on main. |
Summary
MLXTrainingConfig.max_grad_valuedefault 1.0/5.0 silently zeroes a user-suppliedmax_grad_norm, breaking HF/TRL parity. Same hyperparameters that converge undertransformers.SFTTraineron CUDA produce gibberish on MLX.max_grad_value=None. When None or 0, the user'smax_grad_normis honored.max_grad_norm > 0andmax_grad_value > 0are passed explicitly.Bisection, CUDA mirror evidence, and recommended fix are in issue #662.
Test plan
pytest tests/test_pr_a_deep_components.py(22 passed, including new test)unsloth/tests/studio/run_real_mlx_smoke.pygreens on Mac M1 CI withoutmax_grad_value=0workaround